Add Python distributed L4 to L3 dispatch #711
PKUZHOU wants to merge 6 commits into hw-native-sys:main from …
Conversation
Code Review
This pull request implements a distributed L4 to L3 dispatch system using gRPC and protobuf, enabling cross-host task execution. It introduces a callable catalog for remote function registration, a long-running L3 daemon with a fork-safe backend process, and a mailbox shim thread to integrate remote workers into the existing C++ scheduler. Feedback highlights critical issues regarding the transmission of raw memory pointers across host boundaries, which would lead to segmentation faults. Other recommendations include removing redundant logic in the catalog registration, ensuring consistent use of cloudpickle for deserialization, and improving error handling for unexpected backend process terminations.
```python
tag = args.tag(i)
tensors.append(
    dispatch_pb2.ContinuousTensorRef(
        data=int(tensor.data),
```
Sending raw memory pointers (tensor.data) across host boundaries in a distributed system is incorrect. These addresses are local to the L4 process and will be invalid on the remote L3 node, leading to segmentation faults if accessed. A position-independent mechanism, such as handles or offsets into a shared tensor pool, should be used instead.
References
- To ensure shared memory is position-independent for future cross-process/cross-address-space communication, avoid storing absolute pointers (to stack or heap) within shared memory structures. Use relative offsets or process-local handles instead.
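A minimal sketch of the sender side under that scheme, assuming a hypothetical registered `TensorPool` and new `pool_id`/`offset` proto fields in place of the raw `data` pointer (all names illustrative, not the PR's actual API):

```python
# Illustrative only: identify the tensor by (pool_id, offset) within a pool
# both hosts can map, instead of by a sender-local virtual address.
def make_tensor_ref(tensor, tag, pool):
    return dispatch_pb2.ContinuousTensorRef(
        pool_id=pool.pool_id,            # hypothetical proto field
        offset=pool.offset_of(tensor),   # hypothetical pool lookup
        shape=list(tensor.shape),
        dtype=int(tensor.dtype),
        tag=int(tag),
    )
```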
```python
shape = tuple(int(x) for x in ref.shape)
dtype = DataType(int(ref.dtype))
tag = TensorArgType(int(ref.tag))
args.add_tensor(ContinuousTensor.make(int(ref.data), shape, dtype), tag)
```
Reconstructing a ContinuousTensor from a raw pointer received over the network is dangerous. In a distributed environment, ref.data is a pointer from a different address space (and likely a different host), making it invalid for local use.
References
- To ensure shared memory is position-independent for future cross-process/cross-address-space communication, avoid storing absolute pointers (to stack or heap) within shared memory structures. Use relative offsets or process-local handles instead.
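The matching receiver-side sketch, resolving the same hypothetical `(pool_id, offset)` pair against a locally mapped pool so the resulting address is valid in the L3 process (`pools` and `base_addr` are assumptions, not existing APIs):

```python
# Illustrative only: `pools` is a process-local registry of mapped pools;
# base_addr + offset yields an address valid in *this* address space.
def resolve_tensor_ref(ref, pools):
    pool = pools[ref.pool_id]
    local_addr = pool.base_addr + ref.offset
    shape = tuple(int(x) for x in ref.shape)
    dtype = DataType(int(ref.dtype))
    return ContinuousTensor.make(local_addr, shape, dtype)
```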
```python
if callable_id is None:
    self._next_id = max(self._next_id, cid + 1)
else:
    self._next_id = max(self._next_id, cid + 1)
```
The self._next_id update is redundant here because self.install_from_payload(cid, version, payload) (called on line 39) already performs the exact same max(self._next_id, cid + 1) update on line 67. Additionally, the if/else branches are identical.
References
- Reuse existing helper functions or methods instead of duplicating their logic. This improves consistency, maintainability, and reduces the chance of introducing bugs.
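A sketch of the deduplicated path, assuming `install_from_payload()` keeps owning the `_next_id` bookkeeping; the `_hash` helper stands in for whatever hashing the catalog already uses for versions:

```python
def register(self, fn, callable_id=None):
    payload = _pickle_impl.dumps(fn)
    version = self._hash(payload)                     # stand-in for the catalog's hash
    cid = self._next_id if callable_id is None else callable_id
    self.install_from_payload(cid, version, payload)  # updates _next_id exactly once
    return cid, version
```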
```python
def _loads_with_allowlist(payload: bytes, allowed_modules: Optional[Tuple[str, ...]]) -> Callable:
    if allowed_modules is None:
        return pickle.loads(payload)
```
Since cloudpickle (aliased as _pickle_impl) is used for serialization in register, it should also be used for deserialization here to ensure compatibility, especially for lambdas and closures which standard pickle cannot handle.
```diff
-        return pickle.loads(payload)
+        return _pickle_impl.loads(payload)
```
```python
with self._backend_lock:
    self._backend_conn.send(msg)
    ok, payload = self._backend_conn.recv()
```
self._backend_conn.recv() will raise an EOFError if the backend process terminates unexpectedly (e.g., due to a crash). This should be handled to provide a more descriptive error message to the RPC client rather than letting the gRPC handler thread fail with an unhandled exception.
```diff
-        with self._backend_lock:
-            self._backend_conn.send(msg)
-            ok, payload = self._backend_conn.recv()
+        try:
+            with self._backend_lock:
+                self._backend_conn.send(msg)
+                ok, payload = self._backend_conn.recv()
+        except EOFError:
+            raise RuntimeError("L3 daemon backend process terminated unexpectedly") from None
```
```python
def sleep_poll_interval() -> None:
    time.sleep(0.0005)
```
The documented example relies on a counter captured in a closure:

```python
counter = Counter()

def l3_sub(task_args):
    counter.add(task_args.scalar(0))
...
print(f"remote counter={counter.value}")
return 0 if counter.value == 7 else 1
```

But the counter is mutated in the remote L3 process, not at L4, so the L4-local `counter` never changes. Meanwhile, the current dispatch response only reports success/failure:

```python
inner.run(orch_fn, args, cfg)
return dispatch_pb2.DispatchResp(task_id=req.task_id, error_code=0)
```

and no closure state or result value is propagated back to the caller. The example should either avoid expecting L4-local closure state to change, or explicitly use/document an external side effect until richer result propagation exists.
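One way the example could stop depending on closure state, sketched with the accumulated value carried back through an OUTPUT tensor instead; this assumes a single-element u64 output tensor with a writable local view on the remote side:

```python
import ctypes

def l3_sub(task_args):
    # Accumulate into the OUTPUT tensor rather than a captured Counter, so the
    # result travels through the tensor writeback path instead of closure
    # state that never leaves the remote process.
    out = task_args.tensor(0)
    buf = (ctypes.c_uint64 * 1).from_address(int(out.data))
    buf[0] += int(task_args.scalar(0))
```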
uv-xiao left a comment
Implementation review summary for the L4-L3 distributed dispatch PR. I focused on behavioral and semantic issues rather than CI/style.
Main concerns:

- Callable catalog versioning is internally inconsistent: `PullCallable(version=0)` resolves the latest payload bytes but returns version 0, which can fail install-time version validation.
- The documented L4/L3 example constructs `TaskArgs` in an invalid scalar-before-tensor order.
- The dispatch proto exposes both legacy address-like `tensor_args` and remote data-plane `tensor_refs`; the contract should be narrowed or validated so raw L4 addresses are not accidentally treated as remote execution pointers.
- Tensor input staging is currently selected by an internal byte threshold. The distributed L4 programming model should prefer an explicit handle/remote-storage path chosen by the program, with inline bytes only as an explicit small-message/test path.
- `OUTPUT` and `OUTPUT_EXISTING` staging reads and transfers old local buffer bytes even though output-only tensors do not semantically consume prior contents.
- `INOUT` currently has copy-in/copy-out semantics and is excluded from the RXE local-output fast path, so it should not be described as shared or in-place remote memory.
- Tensor-ref dispatch creates an ephemeral backend `Worker(level=3)` per request, which differs materially from persistent local L3 worker reuse and needs rationale/benchmarking or a plan for persistent child-visible tensor storage.
- L3 daemon dispatch goes through serialized blocking foreground-to-backend IPC after the gRPC call; the process split and overhead need to be justified or simplified.
- Heartbeat checks only foreground gRPC liveness, not backend/runtime/device/TensorPool readiness.
- Remote callables use cloudpickle/pickle, whose semantics differ from local fork/COW callable inheritance; the remote callable contract and trusted-cluster assumption should be explicit.
```python
context.abort(grpc.StatusCode.NOT_FOUND, str(e))
return dispatch_pb2.CallablePayload(
    callable_id=request.callable_id,
    version=request.version,
```
version=0 is treated as "latest" by export_payload(), but this response returns the literal request version. That can return valid payload bytes with version=0, and install_from_payload() validates version == hash(payload), so a caller installing the response can fail. Please return the resolved payload version, or remove latest-version semantics from this RPC.
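A sketch of the resolved-version variant, assuming `export_payload()` is changed to return the `(version, payload)` pair it actually resolved (the `payload` field name is assumed from context):

```python
def PullCallable(self, request, context):  # noqa: N802, ANN001
    try:
        # Assumed signature change: export_payload returns the resolved version.
        version, payload = self._catalog.export_payload(request.callable_id, request.version)
    except KeyError as e:
        context.abort(grpc.StatusCode.NOT_FOUND, str(e))
    return dispatch_pb2.CallablePayload(
        callable_id=request.callable_id,
        version=version,   # resolved version, never the literal 0 from the request
        payload=payload,
    )
```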
```python
def l4_orch(orch, task_args, config):
    for value in (2, 5):
        sub_args = TaskArgs()
        sub_args.add_scalar(value)
```
This example calls add_scalar() before add_tensor(), but TaskArgs rejects adding tensors after scalars. Running the documented example fails with RuntimeError: TaskArgs: cannot add tensor after scalar. Please add tensors before scalars, or change the TaskArgs contract/tests if interleaving is intended.
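A corrected ordering for the example, adding the tensor before the scalar; `output_tensor` stands in for whatever tensor the documented example passes through:

```python
def l4_orch(orch, task_args, config):
    for value in (2, 5):
        sub_args = TaskArgs()
        sub_args.add_tensor(output_tensor, TensorArgType.OUTPUT)  # tensors first
        sub_args.add_scalar(value)                                # scalars after
```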
```proto
uint64 callable_version = 3;
bytes config_blob = 4;
repeated uint64 scalar_args = 5;
repeated ContinuousTensorRef tensor_args = 6;
```
tensor_args carries ContinuousTensorRef.data, an address-like value from the sender process. That is not a valid execution pointer on remote L3. Since tensor_refs is the real remote data-plane schema, please clarify whether tensor_args remains supported; if not, reject non-empty tensor_args in the daemon or remove it from the active dispatch path.
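If `tensor_args` is retired from the active path, a daemon-side guard could fail fast instead of dereferencing sender-local addresses; a minimal sketch:

```python
def Dispatch(self, req, context):  # noqa: N802, ANN001
    if req.tensor_args:
        # Reject legacy address-like refs outright rather than treating them
        # as remote execution pointers.
        context.abort(
            grpc.StatusCode.INVALID_ARGUMENT,
            "tensor_args carries sender-local addresses; use tensor_refs",
        )
```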
```python
        local_output_regions.append(region)
        continue
    data = ctypes.string_at(int(tensor.data), nbytes) if nbytes else b""
    if nbytes <= self._tensor_inline_threshold:
```
This makes inline-vs-TensorPool staging an implicit byte-threshold decision. I think the intended L4 programming model should be explicit: the L4 program allocates/registers remote tensor storage, gets a TensorHandle/remote ref, and passes that as the tensor argument. Inline bytes can remain as an explicit small-message/test path, but the normal distributed tensor path should not silently switch based only on size.
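A sketch of the explicit model being suggested, with all names illustrative: the program allocates remote storage, stages bytes deliberately, and passes a handle rather than letting a size threshold pick the transport:

```python
# Hypothetical API: transport is a program decision, not a hidden threshold.
handle = remote_pool.allocate(nbytes)                # explicit remote allocation
remote_pool.write(handle, local_bytes)               # explicit, visible staging
args.add_tensor_handle(handle, TensorArgType.INPUT)  # handle, not raw bytes/address
```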
```python
        refs.append(ref)
        local_output_regions.append(region)
        continue
    data = ctypes.string_at(int(tensor.data), nbytes) if nbytes else b""
```
This reads bytes for every tensor, including OUTPUT/OUTPUT_EXISTING unless the large RXE output path is selected. Old output-buffer contents are not semantic inputs, so small output tensors and fallback output paths send irrelevant bytes to L3. Please separate input staging from output allocation/writeback.
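A sketch of the separation: read local bytes only for tags that are semantic inputs, and send just allocation metadata for output-only tensors:

```python
is_semantic_input = getattr(tag, "name", "") in {"INPUT", "INOUT"}
# Output-only tensors need size/shape for remote allocation, not stale bytes.
data = ctypes.string_at(int(tensor.data), nbytes) if (is_semantic_input and nbytes) else b""
```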
```python
def _should_stage_local_output(self, tag, nbytes: int) -> bool:  # noqa: ANN001
    return (
        self._tensor_transport in {"rxe", "auto"}
        and getattr(tag, "name", "") in {"OUTPUT", "OUTPUT_EXISTING"}
```
The local-output RXE fast path excludes INOUT, so large INOUT is copy-in/copy-out rather than bidirectional RXE. That is acceptable as an MVP if documented, but it should not be described as shared or in-place remote memory.
```python
else:
    args = decode_task_args(req.tensor_args, req.scalar_args)
if req.tensor_refs:
    run_inner = worker_factory()
```
Tensor-ref dispatch creates a fresh backend Worker(level=3) per request, initializes it, runs it, and closes it. This is likely needed so newly materialized mmap buffers exist before L3 children fork, but it is a major lifecycle/performance difference from a persistent L3 worker. Please document this tradeoff and the plan for persistent worker + child-visible tensor storage.
```python
def _backend_call(self, msg):
    if self._backend_conn is None:
        raise RuntimeError("L3 daemon backend is not running")
    with self._backend_lock:
```
Every foreground RPC is serialized through _backend_lock and a blocking Pipe send/recv before backend execution. If this process split is required for fork-safety with gRPC threads, please document it and provide overhead numbers; otherwise consider a direct backend RPC/event loop.
```python
)

def Heartbeat(self, request, context):  # noqa: N802, ANN001
    return dispatch_pb2.Health(ok=True, message="ok")
```
This heartbeat only proves foreground gRPC liveness. It does not check backend process health, worker initialization, runtime/device readiness, TensorPool capacity, or selected transport. Please either rename/document it as liveness only, or add a deeper readiness RPC.
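A sketch of a deeper readiness RPC that at least round-trips the backend pipe; the `("ping", None)` backend message is hypothetical:

```python
def Ready(self, request, context):  # noqa: N802, ANN001
    try:
        ok, _ = self._backend_call(("ping", None))  # exercises the backend process
    except RuntimeError as e:
        return dispatch_pb2.Health(ok=False, message=str(e))
    return dispatch_pb2.Health(ok=bool(ok), message="backend ready")
```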
```python
self._allowed_modules = allowed_modules

def register(self, fn: Callable, callable_id: Optional[int] = None) -> tuple[int, int]:
    payload = _pickle_impl.dumps(fn)
```
Remote callables are serialized with cloudpickle/pickle, which is not equivalent to local fork/COW callable inheritance. Captured mutable state, raw pointers, file descriptors, sockets, locks, device contexts, imports, and side effects expected to be visible at L4 can behave differently. Please document the remote callable contract and the trusted-cluster assumption.
uv-xiao left a comment
Second review round: callable semantics and the L4/L3 remote example.
The main issue is that the public API is ABI-uniform but not semantically uniform. submit_next_level(callable, ...) accepts a uint64, but the semantic meaning changes by level/target: L3-to-L2 expects a chip callable handle, L4-to-L3 expects a Python orchestration callable id, and remote L3 treats the value as a catalog id. Worker.register(fn) also stores both SubWorker callables (fn(task_args)) and orchestration callables (fn(orch, task_args, config)) in one untyped namespace. The current example makes this hard to understand because w4.register(l3_sub) registers a callable intended for a remote L3 SubWorker, while w4.register(l3_orch) registers the remote L3 orchestration function.
I think the minimum direction should be typed callable registration/handles, for example register_sub(...), register_orch(...), and a distinct chip callable handle. The internal slot can still carry a compact integer, but the public API should prevent passing a subworker callable id where a next-level orchestration callable is expected.
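A sketch of what typed handles could look like; `NewType` keeps the wire representation a plain integer while letting type checks (and optional runtime validation in a `_register` helper, assumed here) reject a subworker id where an orchestration id is expected. Names are a proposal, not the PR's API:

```python
from typing import Callable, NewType

SubCallableId = NewType("SubCallableId", int)
OrchCallableId = NewType("OrchCallableId", int)

class Worker:
    def register_sub(self, fn: Callable) -> SubCallableId:
        # fn(task_args): runs on a SubWorker
        return SubCallableId(self._register(fn, kind="sub"))

    def register_orch(self, fn: Callable) -> OrchCallableId:
        # fn(orch, task_args, config): runs as next-level orchestration
        return OrchCallableId(self._register(fn, kind="orch"))

    def submit_next_level(self, callable_id: OrchCallableId, *args, **kwargs):
        ...  # internal slot still carries a compact uint64
```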
```python
endpoints = [item.strip() for item in args.remotes.split(",") if item.strip()]
```

```python
def l3_sub(task_args):
    output = task_args.tensor(1)
```
task_args contains one tensor in this example, added below at line 36, so this should be task_args.tensor(0). As written, the subworker tries to read tensor index 1 and the example cannot demonstrate the intended remote L3 writeback behavior.
```python
sub_cid = w4.register(l3_sub)

def l3_orch(orch, task_args, config):
    orch.submit_sub(sub_cid, task_args)
```
This makes the example hard to interpret. submit_sub() targets an L3 SubWorker, not the next-level remote L3 worker itself. The actual path is L4 submit_next_level(l3_cid) -> remote L3 runs l3_orch -> remote L3 submit_sub(sub_cid) -> SubWorker mutates the tensor. If this example is meant to explain L4-to-remote-L3 dispatch, please either do the mutation directly in l3_orch, or make the example explicitly about "L4 dispatch to remote L3, then remote L3 dispatch to its SubWorker".
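A sketch of the one-hop variant, doing the writeback directly in the remote L3 orchestration callable (same single-element u64 output-buffer assumption as the earlier sketch):

```python
import ctypes

def l3_orch(orch, task_args, config):
    # Mutate the output tensor here, so the example demonstrates exactly
    # L4 -> remote L3, with no additional SubWorker hop.
    output = task_args.tensor(0)
    buf = (ctypes.c_uint64 * 1).from_address(int(output.data))
    buf[0] += int(task_args.scalar(0))
```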
```python
current.value += int(task_args.scalar(0))

w4 = Worker(level=4, num_sub_workers=0)
sub_cid = w4.register(l3_sub)
```
This is semantically confusing: l3_sub is not a W4 SubWorker callable because w4 has num_sub_workers=0; it is intended to be shipped through the L4 catalog and later used by the remote L3 daemon as an L3 SubWorker callable. That works only because Worker.register() is an untyped global callable-id table. Please make the callable kind explicit, for example register_sub(...) vs register_orch(...), or otherwise document and validate that this id is meant for the remote child's SUB namespace.
```python
def l4_orch(orch, task_args, config):
    for value in (2, 5):
        sub_args = TaskArgs()
        sub_args.add_scalar(value)
```
This adds a scalar before adding a tensor, but TaskArgs requires tensors before scalars. The example should add the output tensor first and then the scalar. This concrete bug also obscures the higher-level callable-semantics issue because the example can fail before reaching the remote dispatch path.
```python
    raise RuntimeError("Worker.register() must be called before init()")
cid = len(self._callable_registry)
self._callable_registry[cid] = fn
if self._distributed_catalog is not None:
```
register() stores every Python callable in one _callable_registry, but the callable contracts are different depending on how the id is later used. submit_sub(cid, ...) expects fn(task_args), while L4 submit_next_level(cid, ...) expects a child orchestration callable shaped like fn(orch, task_args, config). The registry should not be untyped here. Please split this into explicit APIs such as register_sub() and register_orch(), or return typed handles so the wrong callable kind cannot be submitted to the wrong worker type.
```python
from .distributed.catalog import Catalog  # noqa: PLC0415

self._distributed_catalog = Catalog()
for cid, fn in self._callable_registry.items():
```
The distributed catalog mirrors every entry from _callable_registry, but it does not preserve whether an entry is a SubWorker callable or a next-level orchestration callable. That is why the L4 example can register both l3_sub and l3_orch on w4 and rely on later submit paths to decide their meaning. For remote dispatch this should be explicit in the catalog payload or handle type; otherwise a wrong id can cross the network and fail only later as a signature/runtime mismatch.
```python
try:
    args = _read_args_from_mailbox(buf)
    cfg = _read_config_from_mailbox(buf)
    proxy.dispatch(int(cid), args, cfg)
```
This forwards the mailbox callable field as a remote catalog id. That is the core semantic mismatch: the same submit_next_level(callable, ...) slot can mean a chip callable handle for L3-to-L2, but a Python orchestration/catalog id for L4-to-L3. The public API should expose a typed next-level orchestration handle here, even if the mailbox representation remains a uint64 internally.